from torch.utils.data import Dataset
import torch
import pytorch_lightning as pl
from itertools import cycle
import random
import os.path as op
import config as cfg
import pickle


from transformers import AutoProcessor
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader

class BalancedBatchSampler:
    """Batch sampler yielding one training-sample index per class in every batch.

    Labels are read from ``<cfg.data_dir>/<dataset>/S<subject>.pkl`` and
    restricted to ``data['train_idx']``. The yielded indices are positions
    inside that training subset, i.e. they index the ``'train'`` split of
    ``EEGDatasetBase`` (which keeps ``eeg[train_idx]``).

    Every batch holds one index from each non-empty class. Classes with fewer
    samples than the largest one are cycled (indices repeat) so that all of
    them appear in every batch. One pass has as many batches as the largest
    class has samples. Class index lists are reshuffled at the start of each pass.
    """

    def __init__(self, dataset='WM', subject=0, num_classes=40):
        """
        :param dataset: dataset sub-directory under ``cfg.data_dir``.
        :param subject: subject number; selects ``S{subject}.pkl``.
        :param num_classes: number of class labels ``0..num_classes-1`` to group by.
        """
        eeg_fname = op.join(cfg.data_dir, dataset, f'S{subject}.pkl')
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(eeg_fname, 'rb') as f:
            data = pickle.load(f)
        train_labels = torch.tensor(data['label'])[data['train_idx']]

        self.label_to_eegs = self._group_by_label(train_labels, num_classes)
        # One cycling iterator per class; rebuilt after shuffling on every __iter__.
        self.class_iterators = [cycle(eeg_idx) for eeg_idx in self.label_to_eegs]

    @staticmethod
    def _group_by_label(labels, num_classes):
        """Return, for each class ``0..num_classes-1``, the positions in ``labels`` holding that class."""
        return [torch.where(labels == c)[0].tolist() for c in range(num_classes)]

    def __iter__(self):
        for eeg_idx in self.label_to_eegs:
            random.shuffle(eeg_idx)
        self.class_iterators = [cycle(eeg_idx) for eeg_idx in self.label_to_eegs]
        # Empty classes are skipped: next(cycle([])) raises StopIteration, which
        # inside a generator would surface as a RuntimeError.
        active = [it for it, eeg_idx in zip(self.class_iterators, self.label_to_eegs) if eeg_idx]
        for _ in range(len(self)):
            yield [next(it) for it in active]

    def __len__(self):
        """Number of batches per pass: the size of the largest class (0 if there are no classes)."""
        return max((len(eeg_idx) for eeg_idx in self.label_to_eegs), default=0)


class EEGDatasetBase(Dataset):
    """EEG samples of one subject and one split, paired with CLIP image features.

    Reads ``<cfg.data_dir>/image_features_FrozenCLIPEmbedder.pkl`` and
    ``<cfg.data_dir>/<dataset>/S<subject>.pkl``, then keeps only the rows
    listed in the split's index array (``train_idx`` / ``val_idx`` / ``test_idx``).
    Items are ``(eeg, label, image_feat)`` tuples.
    """

    def __init__(self, dataset='WM', subject=0, dataset_type='train'):
        '''
        :param dataset: dataset sub-directory under ``cfg.data_dir`` (e.g. 'WM').
        :param subject: subject number; selects the file ``S{subject}.pkl``.
        :param dataset_type: 'train' | 'test' | 'val'
        '''
        assert dataset_type in ['train', 'test', 'val']
        eeg_data_dir = op.join(cfg.data_dir, dataset)
        img_feat_fname = op.join(cfg.data_dir, 'image_features_FrozenCLIPEmbedder.pkl')
        eeg_fname = op.join(eeg_data_dir, f'S{subject}.pkl')

        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(img_feat_fname, 'rb') as f:
            feat_data = pickle.load(f)
        image_feat_dict = feat_data['image_features']
        # Fall back to CPU so the dataset can still be built on hosts without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.image_feat_cls = torch.tensor(
            feat_data['image_features_cls'], dtype=torch.float32).to(device)

        with open(eeg_fname, 'rb') as f:
            data = pickle.load(f)
        eeg = torch.tensor(data['eeg'])
        label = torch.tensor(data['label'])
        image_name = data['image_name']
        # Per-sample image feature, looked up by the sample's image name.
        image_feat = torch.tensor([image_feat_dict[im] for im in image_name])

        split_idx = {'train': data['train_idx'], 'val': data['val_idx'], 'test': data['test_idx']}
        idx = split_idx[dataset_type]
        self.eeg = eeg[idx]
        self.label = label[idx]
        self.image_feat = image_feat[idx]
        self.image_name = [image_name[i] for i in idx]
        self.size = len(self.eeg)

    def __len__(self):
        """Number of samples in the selected split."""
        return self.size

    def __getitem__(self, i):
        """Return ``(eeg, label, image_feat)`` for sample ``i`` of the split."""
        return self.eeg[i], self.label[i], self.image_feat[i]
        # Return
        # return eeg, label


class EEGDataModule(pl.LightningDataModule):
    """Lightning data module serving the train/val/test splits of one subject.

    :param dataset: dataset name, 'WM' or 'SK'.
    :param subject: subject number forwarded to ``EEGDatasetBase``.
    :param batch_size: batch size of every dataloader.
    :param num_workers: worker processes per dataloader (default 4, as before).
    """

    def __init__(self, dataset='WM', subject=1, batch_size=1, num_workers=4):
        super().__init__()
        assert dataset in ['WM', 'SK']
        self.dataset = dataset
        self.batch_size = batch_size
        self.subject = subject
        self.num_workers = num_workers

    def _make_dataset(self, dataset_type):
        """Build the ``EEGDatasetBase`` split ``dataset_type`` for this module's dataset/subject."""
        return EEGDatasetBase(dataset=self.dataset,
                              subject=self.subject,
                              dataset_type=dataset_type)

    def _make_loader(self, ds):
        """Wrap ``ds`` in an unshuffled DataLoader using the module's batch size and worker count."""
        return DataLoader(ds, batch_size=self.batch_size, num_workers=self.num_workers)

    def setup(self, stage=None):
        """Load all three splits; ``stage`` is ignored."""
        self.train_dataset = self._make_dataset('train')
        self.test_dataset = self._make_dataset('test')
        self.val_dataset = self._make_dataset('val')

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)


class EEGDataModuleBalanced(EEGDataModule):
    """``EEGDataModule`` whose training loader draws class-balanced batches.

    Training batches come from a ``BalancedBatchSampler`` built in the
    constructor for the same dataset and subject. The training loader therefore
    ignores ``batch_size``, while the inherited validation and test loaders
    still use it.
    """

    def __init__(self, dataset, subject=1, batch_size=1):
        super().__init__(dataset, subject, batch_size)
        # Built eagerly: the sampler reads the subject's pickle to group train indices by class.
        sampler = BalancedBatchSampler(subject=subject, dataset=dataset)
        self.batch_sampler = sampler

    def setup(self, stage=None):
        # Nothing extra to prepare; the parent loads every split.
        super().setup()

    def train_dataloader(self):
        loader_kwargs = dict(batch_sampler=self.batch_sampler, num_workers=4)
        return DataLoader(self.train_dataset, **loader_kwargs)


def identity(x):
    """Return ``x`` unchanged; used as the no-op default ``image_transform``."""
    unchanged = x
    return unchanged

class EEGDatasetGen(EEGDatasetBase):
    """EEG split paired with the raw source images of each sample.

    Extends ``EEGDatasetBase`` by resolving each sample's image file as
    ``<imagenet_path>/<prefix>/<image_name>.JPEG``. Here ``prefix`` is the part
    of the image name before the first ``'_'`` (presumably the ImageNet synset
    id; confirm against the data layout). Items are dicts with the EEG, the
    label, the transformed image and the CLIP-processor inputs.
    """

    def __init__(self, dataset='WM', subject=0, dataset_type='train',
                 image_transform=identity,
                 imagenet_path=None):
        """
        :param dataset: dataset sub-directory under ``cfg.data_dir``.
        :param subject: subject number.
        :param dataset_type: 'train' | 'test' | 'val'
        :param image_transform: callable applied to the [0, 1]-scaled image array.
        :param imagenet_path: root directory of the image files (required).
        :raises TypeError: if ``imagenet_path`` is None.
        """
        if imagenet_path is None:
            # Fail fast, before the base class loads its pickles.
            raise TypeError('EEGDatasetGen requires imagenet_path (root directory of the images)')
        super().__init__(dataset, subject, dataset_type)
        self.image_transform = image_transform
        self.imagenet_path = imagenet_path
        self.data_len = 500  # kept for backward compatibility; not used by this class
        imagenet_root = op.abspath(imagenet_path)
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.image_path = [op.join(imagenet_root, name.split('_')[0], f'{name}.JPEG')
                           for name in self.image_name]

    def __getitem__(self, i):
        """Return the sample ``i`` as a dict.

        Keys: ``eeg`` and ``label`` (from the base split); ``image``, the RGB
        pixel array divided by 255.0 and passed through ``image_transform``;
        ``image_raw``, the CLIP processor output whose ``pixel_values`` has its
        leading (batch) dimension squeezed away.
        """
        # Close the file handle promptly; convert() loads the pixels into a new image first.
        with Image.open(self.image_path[i]) as img:
            rgb = img.convert('RGB')
        image = np.array(rgb) / 255.0
        clip_inputs = self.processor(images=rgb, return_tensors="pt")
        clip_inputs['pixel_values'] = clip_inputs['pixel_values'].squeeze(0)
        return {'eeg': self.eeg[i], 'label': self.label[i],
                'image': self.image_transform(image), 'image_raw': clip_inputs}

def create_EEG_dataset(dataset='WM',
                       subject=0,
                       image_transform=identity,
                       imagenet_path=None):
    """Build the train/val/test ``EEGDatasetGen`` splits for one subject.

    :param dataset: dataset sub-directory under ``cfg.data_dir``.
    :param subject: subject number.
    :param image_transform: either a single callable used for every split, or a
        ``(train_transform, eval_transform)`` pair given as a list or tuple.
        The second transform is used for both the validation and test splits.
    :param imagenet_path: root directory of the image files.
    :return: ``(dataset_train, dataset_val, dataset_test)``
    """
    # Tuples used to be treated as a single (non-callable) transform; accept both sequence kinds.
    if isinstance(image_transform, (list, tuple)):
        train_tf, eval_tf = image_transform[0], image_transform[1]
    else:
        train_tf = eval_tf = image_transform
    common = dict(dataset=dataset, subject=subject, imagenet_path=imagenet_path)
    dataset_train = EEGDatasetGen(image_transform=train_tf, dataset_type='train', **common)
    dataset_val = EEGDatasetGen(image_transform=eval_tf, dataset_type='val', **common)
    dataset_test = EEGDatasetGen(image_transform=eval_tf, dataset_type='test', **common)
    return (dataset_train, dataset_val, dataset_test)


if __name__ == '__main__':
    # Smoke run: constructing the sampler loads subject 0's 'WM' pickle from cfg.data_dir.
    sampler = BalancedBatchSampler(subject=0, dataset='WM')


